{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np \n",
    "import xml.etree.ElementTree as ET\n",
    "from scipy.io import loadmat\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from PIL import Image\n",
    "from tqdm.notebook import tqdm\n",
    "from sklearn import metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "NVIDIA GeForce GTX 1070\n"
     ]
    }
   ],
   "source": [
    "# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(torch.cuda.is_available())\n",
    "# torch.cuda.get_device_name raises when no CUDA device is available,\n",
    "# so the name is only queried on the CUDA path.\n",
    "if device.type == \"cuda\":\n",
    "    print(torch.cuda.get_device_name(0))\n",
    "\n",
    "def load_dataset():\n",
    "    \"\"\"Load the train/test split lists and the per-image annotation metadata.\n",
    "\n",
    "    Reads ``lists/train_list.mat`` and ``lists/test_list.mat``. For every listed\n",
    "    image the matching file under ``Annotation/`` (the image path without its\n",
    "    ``.jpg`` suffix) is parsed as XML.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(train_data, test_data, categories)``. The first two are lists\n",
    "        of dicts with the keys ``image`` (path under ``Images/``), ``label``\n",
    "        (the stored label minus 1), ``label_name`` (directory part of the image\n",
    "        path), ``width``, ``height``, ``xmin``, ``ymin``, ``xmax`` and ``ymax``.\n",
    "        ``categories`` maps every ``label_name`` to its 0-indexed label.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a split list holds a different number of files and labels.\n",
    "    \"\"\"\n",
    "\n",
    "    def read_split(list_path):\n",
    "        # Return the image paths and labels stored in one split list\n",
    "        split = loadmat(list_path)\n",
    "        images = [x[0][0] for x in split['file_list']]\n",
    "        labels = [x[0] for x in split['labels']]\n",
    "        if len(images) != len(labels):\n",
    "            raise ValueError('{} lists {} files but {} labels'.format(list_path, len(images), len(labels)))\n",
    "        return images, labels\n",
    "\n",
    "    def read_record(image, label, label_name):\n",
    "        # Build one sample record from the annotation XML of the image\n",
    "        annotation_path = os.path.join(\"Annotation\", image.replace(\".jpg\", \"\"))\n",
    "        root = ET.parse(annotation_path).getroot()\n",
    "        size = root.find(\"size\")\n",
    "        # Only the first <object> element is read, even if the file holds several\n",
    "        bndbox = root.find(\"object\").find(\"bndbox\")\n",
    "        return dict(\n",
    "            image=os.path.join(\"Images\", image),\n",
    "            label=label-1,\n",
    "            label_name=label_name,\n",
    "            width=int(size.find(\"width\").text),\n",
    "            height=int(size.find(\"height\").text),\n",
    "            xmin=int(bndbox.find(\"xmin\").text),\n",
    "            ymin=int(bndbox.find(\"ymin\").text),\n",
    "            xmax=int(bndbox.find(\"xmax\").text),\n",
    "            ymax=int(bndbox.find(\"ymax\").text)\n",
    "        )\n",
    "\n",
    "    # Get train list\n",
    "    train_images, train_labels = read_split(\"lists/train_list.mat\")\n",
    "\n",
    "    # Get test list\n",
    "    test_images, test_labels = read_split(\"lists/test_list.mat\")\n",
    "\n",
    "    # Gather data\n",
    "    train_data = []\n",
    "    test_data = []\n",
    "\n",
    "    # Record category ids\n",
    "    categories = {}\n",
    "\n",
    "    # Train images are processed first, then test images; each split fills its own list\n",
    "    splits = (\n",
    "        (train_images, train_labels, train_data),\n",
    "        (test_images, test_labels, test_data),\n",
    "    )\n",
    "    for images, labels, records in splits:\n",
    "        for image, label in zip(images, labels):\n",
    "            label_name = os.path.split(image)[0]\n",
    "            # Label -1 to make it 0-indexed\n",
    "            categories[label_name] = label-1\n",
    "            records.append(read_record(image, label, label_name))\n",
    "\n",
    "    return train_data, test_data, categories\n",
    "\n",
    "# Load both splits once, then wrap each list of records in its own DataFrame\n",
    "train_data, test_data, categories = load_dataset()\n",
    "dftrain, dftest = (pd.DataFrame(records) for records in (train_data, test_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of classes:  120\n",
      "Number of training samples:  12000\n",
      "Number of testing samples:  4290\n",
      "Number of validation samples:  4290\n"
     ]
    }
   ],
   "source": [
    "# Summarise the class count and the split sizes; the test list is reported\n",
    "# as two equal halves (one for testing, one for validation)\n",
    "half_of_test = len(dftest)//2\n",
    "summary_rows = [\n",
    "    (\"Number of classes: \", len(categories)),\n",
    "    (\"Number of training samples: \", len(dftrain)),\n",
    "    (\"Number of testing samples: \", half_of_test),\n",
    "    (\"Number of validation samples: \", half_of_test),\n",
    "]\n",
    "for caption, count in summary_rows:\n",
    "    print(caption, count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[[ 1.6495,  1.6153,  1.5639,  ..., -0.9363, -0.9363, -0.9363],\n",
       "          [ 1.6667,  1.6324,  1.5810,  ..., -0.9363, -0.9363, -0.9363],\n",
       "          [ 1.6838,  1.6495,  1.5982,  ..., -0.9363, -0.9363, -0.9363],\n",
       "          ...,\n",
       "          [-1.2788, -1.2788, -1.2617,  ..., -0.0287, -0.0801, -0.1143],\n",
       "          [-1.4672, -1.4672, -1.4843,  ..., -0.0116, -0.0629, -0.0972],\n",
       "          [-1.5528, -1.5699, -1.5870,  ..., -0.0116, -0.0458, -0.0801]],\n",
       " \n",
       "         [[ 1.0280,  0.9930,  0.9405,  ..., -1.0903, -1.0903, -1.0903],\n",
       "          [ 1.0455,  1.0105,  0.9580,  ..., -1.0903, -1.0903, -1.0903],\n",
       "          [ 1.0630,  1.0280,  0.9755,  ..., -1.0903, -1.0903, -1.0903],\n",
       "          ...,\n",
       "          [ 0.3452,  0.3803,  0.4328,  ..., -0.1099, -0.0924, -0.0749],\n",
       "          [ 0.1877,  0.1877,  0.2052,  ..., -0.0924, -0.0749, -0.0574],\n",
       "          [ 0.1001,  0.1001,  0.1001,  ..., -0.0924, -0.0574, -0.0399]],\n",
       " \n",
       "         [[ 0.8797,  0.8448,  0.7925,  ..., -1.5430, -1.5430, -1.5430],\n",
       "          [ 0.8971,  0.8622,  0.8099,  ..., -1.5430, -1.5430, -1.5430],\n",
       "          [ 0.9145,  0.8797,  0.8274,  ..., -1.5430, -1.5430, -1.5430],\n",
       "          ...,\n",
       "          [ 2.0474,  2.0648,  2.0823,  ...,  0.2871,  0.2696,  0.2696],\n",
       "          [ 1.8557,  1.8383,  1.8208,  ...,  0.3045,  0.2871,  0.2871],\n",
       "          [ 1.7511,  1.7337,  1.6988,  ...,  0.3045,  0.3045,  0.3045]]]),\n",
       " 0)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inherit from Dataset\n",
    "class CustomDataset(Dataset):\n",
    "    \"\"\"Dataset over a DataFrame of image records.\n",
    "\n",
    "    Every row must provide an ``image`` column (a path opened with PIL) and a\n",
    "    ``label`` column; other columns are ignored.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, df, transform=None):\n",
    "        \"\"\"Keep the records and the optional per-image transform.\n",
    "\n",
    "        Args:\n",
    "            df: DataFrame with ``image`` and ``label`` columns.\n",
    "            transform: optional callable applied to the RGB PIL image.\n",
    "        \"\"\"\n",
    "        self.df = df\n",
    "        self.transform = transform\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of records.\"\"\"\n",
    "        return len(self.df)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return ``(image, label)`` for the record at position ``idx``.\"\"\"\n",
    "        row = self.df.iloc[idx]\n",
    "        # Image.open keeps the file open until the pixels are read; convert()\n",
    "        # reads them into a new image, so the handle can be closed right after\n",
    "        with Image.open(row['image']) as source:\n",
    "            image = source.convert('RGB')\n",
    "        label = row['label']\n",
    "        if self.transform:\n",
    "            image = self.transform(image)\n",
    "        return image, label\n",
    "        \n",
    "#Pre-Processing\n",
    "# Per-channel mean and standard deviation used to normalize both pipelines\n",
    "# (these look like the usual ImageNet statistics - confirm against the weights in use)\n",
    "NORMALIZE_MEAN = [0.485, 0.456, 0.406]\n",
    "NORMALIZE_STD = [0.229, 0.224, 0.225]\n",
    "# Side length of the square crop produced by both pipelines\n",
    "CROP_SIZE = 224\n",
    "# Size the evaluation images are resized to before the center crop\n",
    "EVAL_RESIZE = 256\n",
    "\n",
    "train_transforms = transforms.Compose([\n",
    "    # Randomly resize and crop the image to CROP_SIZE\n",
    "    transforms.RandomResizedCrop(CROP_SIZE),\n",
    "    # Randomly flip the image horizontally\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    # Convert the image to a PyTorch Tensor\n",
    "    transforms.ToTensor(),\n",
    "    # Normalize the image\n",
    "    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)\n",
    "])\n",
    "\n",
    "test_transforms = transforms.Compose([\n",
    "    # Resize the image to EVAL_RESIZE\n",
    "    transforms.Resize(EVAL_RESIZE),\n",
    "    # Crop the center of the image\n",
    "    transforms.CenterCrop(CROP_SIZE),\n",
    "    # Convert the image to a PyTorch Tensor\n",
    "    transforms.ToTensor(),\n",
    "    # Normalize the image\n",
    "    transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)\n",
    "])\n",
    "\n",
    "train_dataset = CustomDataset(dftrain, transform=train_transforms)\n",
    "\n",
    "#Split test set in half to test and validation sets\n",
    "# The halves are drawn per class with a fixed seed instead of cutting the list at a\n",
    "# fixed row: a positional cut gives each half different classes whenever the list is\n",
    "# grouped by class (the saved validation output shows all-zero confusion-matrix rows\n",
    "# for the last classes, which suggests it is). Per-class rounding means the two\n",
    "# halves are about, not exactly, equal in size.\n",
    "SPLIT_SEED = 0\n",
    "dftest_half = dftest.groupby('label').sample(frac=0.5, random_state=SPLIT_SEED)\n",
    "dfvalidation_half = dftest.drop(index=dftest_half.index)\n",
    "test_dataset = CustomDataset(dftest_half, transform=test_transforms)\n",
    "validation_dataset = CustomDataset(dfvalidation_half, transform=test_transforms)\n",
    "\n",
    "# Test\n",
    "train_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import resnet18\n",
    "\n",
    "# Load the pre-trained model\n",
    "model = resnet18(weights=None)\n",
    "# Load the checkpoint onto the CPU first; the model is moved to `device` below\n",
    "model.load_state_dict(torch.load(\"resnet18-f37072fd.pth\", map_location=\"cpu\"))\n",
    "\n",
    "#Freeze layers\n",
    "# Number of leading child modules whose parameters are frozen.\n",
    "# 0 keeps every layer trainable; e.g. 4 freezes the first four children.\n",
    "FROZEN_CHILDREN = 0\n",
    "for child in list(model.children())[:FROZEN_CHILDREN]:\n",
    "    for param in child.parameters():\n",
    "        param.requires_grad = False\n",
    "\n",
    "# Change the output layer\n",
    "# Read the input width from the existing head instead of hard-coding it\n",
    "model.fc = torch.nn.Linear(model.fc.in_features, len(categories))\n",
    "model = model.to(device)\n",
    "\n",
    "\n",
    "\n",
    "def train(epochs = 12, batches = 32, learing_rate = 1e-4):\n",
    "    \"\"\"Train the notebook-level ``model`` and score it on the validation set after every epoch.\n",
    "\n",
    "    The model is optimised in place; each call creates a fresh Adam optimizer and\n",
    "    StepLR schedule but continues from whatever weights ``model`` currently holds.\n",
    "\n",
    "    Args:\n",
    "        epochs: number of passes over ``train_dataset``.\n",
    "        batches: batch size used by both data loaders.\n",
    "        learing_rate: initial learning rate for Adam (spelling kept for existing callers).\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(epoch_accuracy, accuracy)`` - the list of per-epoch validation\n",
    "        accuracies and the accuracy of the last epoch (``None`` when ``epochs`` is 0).\n",
    "    \"\"\"\n",
    "\n",
    "    # Define the loss function\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr= learing_rate)\n",
    "    # Multiplies the learning rate by 0.1 every 5 scheduler steps\n",
    "    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)\n",
    "    \n",
    "    # Define the data loaders\n",
    "    train_loader = DataLoader(train_dataset, batch_size= batches, shuffle=True)\n",
    "    # Score on the validation half so the test half stays unseen until the final evaluation\n",
    "    validation_loader = DataLoader(validation_dataset, batch_size= batches, shuffle=False)\n",
    "    \n",
    "    # Train the model\n",
    "    model.train()\n",
    "    \n",
    "    epoch_accuracy = []\n",
    "    # Stays None when no epoch runs, so the return below never hits an unbound name\n",
    "    accuracy = None\n",
    "    for epoch in range(epochs):\n",
    "        for images, labels in tqdm(train_loader, desc=\"Train Epoch \" + str(epoch)):\n",
    "            images = images.to(device)\n",
    "            labels = labels.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(images)\n",
    "            loss = criterion(outputs, labels.long())\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        # Advance the schedule once per epoch, after the optimizer updates;\n",
    "        # without this call the learning rate never decays\n",
    "        scheduler.step()\n",
    "    \n",
    "        # Evaluate the model using validation set\n",
    "        model.eval()\n",
    "        y_true = []\n",
    "        y_pred = []\n",
    "        with torch.no_grad():\n",
    "            for images, labels in tqdm(validation_loader, desc=\"Validation Epoch \" + str(epoch)):\n",
    "                images = images.to(device)\n",
    "                labels = labels.to(device)\n",
    "                outputs = model(images)\n",
    "                _, predicted = torch.max(outputs, 1)\n",
    "                y_true.extend(labels.cpu().numpy())\n",
    "                y_pred.extend(predicted.cpu().numpy())\n",
    "        # Back to training mode for the next epoch (and for the caller)\n",
    "        model.train()\n",
    "        accuracy = metrics.accuracy_score(y_true, y_pred)\n",
    "        print(\"Epoch: \", epoch)\n",
    "        print(\"Accurarcy: \", accuracy)\n",
    "        print(\"F1 Score: \", metrics.f1_score(y_true, y_pred, average='macro'))\n",
    "        print(\"Confusion Matrix: \", metrics.confusion_matrix(y_true, y_pred))\n",
    "        epoch_accuracy.append(accuracy)\n",
    "\n",
    "    return epoch_accuracy , accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _plot_accuracy(x_values, accuracies, title, xlabel, filename, log_x = False):\n",
    "    \"\"\"Plot accuracy against one hyperparameter on a cleared figure and save it to ``filename``.\"\"\"\n",
    "    plt.clf()\n",
    "    plt.plot(x_values, accuracies)\n",
    "    if log_x:\n",
    "        # Values spread over several orders of magnitude bunch up on a linear axis\n",
    "        plt.xscale(\"log\")\n",
    "    plt.title(title)\n",
    "    plt.xlabel(xlabel)\n",
    "    plt.ylabel(\"Accuracy\")\n",
    "    plt.savefig(filename)\n",
    "\n",
    "\n",
    "def eval_params(epoch = False,batch = False,learnr = False):\n",
    "    \"\"\"Run the selected hyperparameter sweeps and save one accuracy plot per sweep.\n",
    "\n",
    "    Every training run starts from the weights ``model`` holds when this function\n",
    "    is called, so the settings within a sweep are compared from the same start.\n",
    "\n",
    "    Args:\n",
    "        epoch: train for 40 epochs and plot per-epoch accuracy to ``epochs.png``.\n",
    "        batch: compare batch sizes 10, 32, 64 and 128; plot to ``batch.png``.\n",
    "        learnr: compare learning rates 1e-3 to 1e-6; plot to ``learning_rates.png``.\n",
    "    \"\"\"\n",
    "    if not (epoch or batch or learnr):\n",
    "        return\n",
    "\n",
    "    import copy\n",
    "\n",
    "    # train() optimises the notebook-level `model` in place. Without a reset every\n",
    "    # run would continue from the weights left by the previous run, so later\n",
    "    # settings would look better simply because they trained for longer.\n",
    "    initial_state = copy.deepcopy(model.state_dict())\n",
    "\n",
    "    def train_from_start(**train_kwargs):\n",
    "        # Restore the starting weights, then train with the given settings\n",
    "        model.load_state_dict(initial_state)\n",
    "        return train(**train_kwargs)\n",
    "\n",
    "    if epoch:\n",
    "        #evaluate best epoch sizes\n",
    "        ep1 , _ = train_from_start(epochs = 40)\n",
    "        print(ep1)\n",
    "\n",
    "        #Plot the results\n",
    "        _plot_accuracy(range(len(ep1)), ep1, \"Epoch Size vs Accuracy\", \"Epoch Size\", \"epochs.png\")\n",
    "\n",
    "    if batch:\n",
    "        # evaluate best batch sizes\n",
    "        batch_sizes = [10,32,64,128]\n",
    "        batch_accuracy = [train_from_start(batches = size)[1] for size in batch_sizes]\n",
    "\n",
    "        #Plot the results\n",
    "        _plot_accuracy(batch_sizes, batch_accuracy, \"Batch Size vs Accuracy\", \"batch Size\", \"batch.png\")\n",
    "\n",
    "    if learnr:\n",
    "        # evaluate learning rates\n",
    "        learning_rates = [1e-3,1e-4,1e-5,1e-6]\n",
    "        lr_accuracy = [train_from_start(learing_rate = lr)[1] for lr in learning_rates]\n",
    "\n",
    "        #Plot the results\n",
    "        _plot_accuracy(learning_rates, lr_accuracy, \"LR vs Accuracy\", \"Learning Rate\", \"learning_rates.png\", log_x = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the next cell to compare the hyperparameters handled by `eval_params`: select which ones to evaluate by setting the corresponding arguments to True."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "34fdd026cf72431cb65d30b829fc7630",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train Epoch 0:   0%|          | 0/1200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\mehdu\\AppData\\Local\\Programs\\Python\\Python311\\Lib\\site-packages\\torch\\autograd\\graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ..\\aten\\src\\ATen\\native\\cudnn\\Conv_v8.cpp:919.)\n",
      "  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d0a6ec118a774bb78a0b131eec9f18ff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation Epoch 0:   0%|          | 0/429 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  0\n",
      "Accurarcy:  0.5841491841491842\n",
      "F1 Score:  0.2869756977195615\n",
      "Confusion Matrix:  [[21  0  0 ...  0  1  0]\n",
      " [ 0 50  0 ...  0  0  0]\n",
      " [ 0  0 86 ...  0  0  0]\n",
      " ...\n",
      " [ 0  0  0 ...  0  0  0]\n",
      " [ 0  0  0 ...  0  0  0]\n",
      " [ 0  0  0 ...  0  0  0]]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e462c2c6590d4d5aa42d7e1cf3bb0664",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train Epoch 1:   0%|          | 0/1200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d2e47929f1c4039bfeed10e501a691f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation Epoch 1:   0%|          | 0/429 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  1\n",
      "Accurarcy:  0.6461538461538462\n",
      "F1 Score:  0.3264267874441229\n",
      "Confusion Matrix:  [[ 32   0   0 ...   2   0   0]\n",
      " [  0  67   0 ...   0   0   0]\n",
      " [  1   0 104 ...   0   0   0]\n",
      " ...\n",
      " [  0   0   0 ...   0   0   0]\n",
      " [  0   0   0 ...   0   0   0]\n",
      " [  0   0   0 ...   0   0   0]]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be2120ee5f194630b6a61802528cb6b8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train Epoch 2:   0%|          | 0/1200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "709b115a95b74324b26105e3cdf14662",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation Epoch 2:   0%|          | 0/429 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  2\n",
      "Accurarcy:  0.6983682983682984\n",
      "F1 Score:  0.3303270579608136\n",
      "Confusion Matrix:  [[ 16   0   0 ...   0   0   0]\n",
      " [  0  72   0 ...   0   0   0]\n",
      " [  0   0 113 ...   0   0   0]\n",
      " ...\n",
      " [  0   0   0 ...   0   0   0]\n",
      " [  0   0   0 ...   0   0   0]\n",
      " [  0   0   0 ...   0   0   0]]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cebbc0d8ac08497494033736bd0a9894",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train Epoch 3:   0%|          | 0/1200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e15309330da3454f8a7a32ac46e1436d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation Epoch 3:   0%|          | 0/429 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  3\n",
      "Accurarcy:  0.6846153846153846\n",
      "F1 Score:  0.32580656777692185\n",
      "Confusion Matrix:  [[ 40   0   0 ...   1   0   0]\n",
      " [  0  73   0 ...   0   0   0]\n",
      " [  0   0 105 ...   0   0   0]\n",
      " ...\n",
      " [  0   0   0 ...   0   0   0]\n",
      " [  0   0   0 ...   0   0   0]\n",
      " [  0   0   0 ...   0   0   0]]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9812633f42e6420e9058dd713573ba6b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train Epoch 4:   0%|          | 0/1200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "993518c5c2734d62a9f5ce2c6246e3f1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation Epoch 4:   0%|          | 0/429 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  4\n",
      "Accurarcy:  0.7046620046620047\n",
      "F1 Score:  0.3355463549146909\n",
      "Confusion Matrix:  [[ 28   0   0 ...   0   0   0]\n",
      " [  0  73   0 ...   0   0   0]\n",
      " [  0   1 125 ...   0   0   0]\n",
      " ...\n",
      " [  0   0   0 ...   0   0   0]\n",
      " [  0   0   0 ...   0   0   0]\n",
      " [  0   0   0 ...   0   0   0]]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4587942505234ef0abadbd3b348e20e4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train Epoch 5:   0%|          | 0/1200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "821a8ec54e6c4e72993922bc0883d71c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation Epoch 5:   0%|          | 0/429 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  5\n",
      "Accurarcy:  0.6944055944055944\n",
      "F1 Score:  0.3335407846941762\n",
      "Confusion Matrix:  [[ 31   0   0 ...   0   1   0]\n",
      " [  0  75   0 ...   0   0   0]\n",
      " [  1   1 117 ...   0   0   0]\n",
      " ...\n",
      " [  0   0   0 ...   0   0   0]\n",
      " [  0   0   0 ...   0   0   0]\n",
      " [  0   0   0 ...   0   0   0]]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6fb4a990cfd04744af9fcd3f73c00564",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train Epoch 6:   0%|          | 0/1200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eb7a959fcaed41d4a295ffb9069e1832",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation Epoch 6:   0%|          | 0/429 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  6\n",
      "Accurarcy:  0.7116550116550117\n",
      "F1 Score:  0.3589262847198806\n",
      "Confusion Matrix:  [[ 24   0   0 ...   2   0   0]\n",
      " [  0  74   0 ...   0   0   0]\n",
      " [  0   1 106 ...   0   0   0]\n",
      " ...\n",
      " [  0   0   0 ...   0   0   0]\n",
      " [  0   0   0 ...   0   0   0]\n",
      " [  0   0   0 ...   0   0   0]]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "938dc847e821428c8e0b4fc4fd006be3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train Epoch 7:   0%|          | 0/1200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3a185638ec3145c6aa4db2a6322efb7f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation Epoch 7:   0%|          | 0/429 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  7\n",
      "Accurarcy:  0.7011655011655011\n",
      "F1 Score:  0.33831113665478957\n",
      "Confusion Matrix:  [[ 34   0   0 ...   0   0   0]\n",
      " [  0  73   0 ...   0   0   0]\n",
      " [  0   0 123 ...   0   0   0]\n",
      " ...\n",
      " [  0   0   0 ...   0   0   0]\n",
      " [  0   0   0 ...   0   0   0]\n",
      " [  0   0   0 ...   0   0   0]]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "98590296815c4c65a60bf961801467cc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Train Epoch 8:   0%|          | 0/1200 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Sweep batch sizes and learning rates; the epoch sweep stays switched off\n",
    "eval_params(batch = True, learnr = True, epoch = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run this Section in order to train and evaluate the model on the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate on the test set\n",
    "# Batch size for the test loader. `batches` only exists as a parameter inside\n",
    "# train(), so it cannot be used here; 32 matches train()'s default.\n",
    "TEST_BATCH_SIZE = 32\n",
    "\n",
    "train()\n",
    "model.eval()\n",
    "test_loader = DataLoader(test_dataset, batch_size= TEST_BATCH_SIZE, shuffle=False)\n",
    "\n",
    "y_true = []\n",
    "y_pred = []\n",
    "# No gradients are needed for inference\n",
    "with torch.no_grad():\n",
    "    for images, labels in tqdm(test_loader):\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device)\n",
    "        outputs = model(images)\n",
    "        # Predicted class = index of the largest output along dim 1\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        y_true.extend(labels.cpu().numpy())\n",
    "        y_pred.extend(predicted.cpu().numpy())\n",
    "\n",
    "print(\"Accuracy: \", metrics.accuracy_score(y_true, y_pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                          precision    recall  f1-score   support\n",
      "\n",
      "                     n02085620-Chihuahua       0.64      0.71      0.67        52\n",
      "              n02085782-Japanese_spaniel       0.82      0.87      0.85        85\n",
      "                   n02085936-Maltese_dog       0.87      0.76      0.81       152\n",
      "                      n02086079-Pekinese       0.73      0.78      0.75        49\n",
      "                      n02086240-Shih-Tzu       0.69      0.72      0.70       114\n",
      "              n02086646-Blenheim_spaniel       0.87      0.86      0.87        88\n",
      "                      n02086910-papillon       0.97      0.91      0.94        96\n",
      "                   n02087046-toy_terrier       0.83      0.67      0.74        72\n",
      "           n02087394-Rhodesian_ridgeback       0.72      0.40      0.52        72\n",
      "                  n02088094-Afghan_hound       0.95      0.95      0.95       139\n",
      "                        n02088238-basset       0.71      0.80      0.75        75\n",
      "                        n02088364-beagle       0.86      0.62      0.72        95\n",
      "                    n02088466-bloodhound       0.93      0.76      0.84        87\n",
      "                      n02088632-bluetick       0.88      0.73      0.80        71\n",
      "       n02089078-black-and-tan_coonhound       0.95      0.61      0.74        59\n",
      "                  n02089867-Walker_hound       0.47      0.66      0.55        53\n",
      "              n02089973-English_foxhound       0.72      0.63      0.67        57\n",
      "                       n02090379-redbone       0.61      0.75      0.67        48\n",
      "                        n02090622-borzoi       0.65      0.78      0.71        51\n",
      "               n02090721-Irish_wolfhound       0.99      0.58      0.73       118\n",
      "             n02091032-Italian_greyhound       0.75      0.67      0.71        82\n",
      "                       n02091134-whippet       0.73      0.69      0.71        87\n",
      "                  n02091244-Ibizan_hound       0.92      0.82      0.87        88\n",
      "            n02091467-Norwegian_elkhound       0.95      0.91      0.93        96\n",
      "                    n02091635-otterhound       0.71      0.92      0.80        51\n",
      "                        n02091831-Saluki       0.84      0.80      0.82       100\n",
      "            n02092002-Scottish_deerhound       0.76      0.86      0.81       132\n",
      "                    n02092339-Weimaraner       0.96      0.82      0.88        60\n",
      "     n02093256-Staffordshire_bullterrier       0.72      0.47      0.57        55\n",
      "n02093428-American_Staffordshire_terrier       0.54      0.69      0.61        64\n",
      "            n02093647-Bedlington_terrier       0.94      0.90      0.92        82\n",
      "                n02093754-Border_terrier       0.89      0.93      0.91        72\n",
      "            n02093859-Kerry_blue_terrier       0.86      0.71      0.78        79\n",
      "                 n02093991-Irish_terrier       0.88      0.75      0.81        69\n",
      "               n02094114-Norfolk_terrier       0.69      0.74      0.71        72\n",
      "               n02094258-Norwich_terrier       0.73      0.73      0.73        85\n",
      "             n02094433-Yorkshire_terrier       0.49      0.72      0.59        64\n",
      "       n02095314-wire-haired_fox_terrier       0.73      0.63      0.68        57\n",
      "              n02095570-Lakeland_terrier       0.74      0.58      0.65        97\n",
      "              n02095889-Sealyham_terrier       0.98      0.86      0.92       102\n",
      "                      n02096051-Airedale       0.81      0.85      0.83       102\n",
      "                         n02096177-cairn       0.79      0.66      0.72        97\n",
      "            n02096294-Australian_terrier       0.71      0.71      0.71        96\n",
      "                n02096437-Dandie_Dinmont       0.97      0.74      0.84        80\n",
      "                   n02096585-Boston_bull       0.90      0.87      0.88        82\n",
      "           n02097047-miniature_schnauzer       0.68      0.74      0.71        54\n",
      "               n02097130-giant_schnauzer       0.51      0.72      0.60        57\n",
      "            n02097209-standard_schnauzer       0.60      0.47      0.53        55\n",
      "                n02097298-Scotch_terrier       0.70      0.86      0.78        58\n",
      "               n02097474-Tibetan_terrier       0.61      0.75      0.67       106\n",
      "                 n02097658-silky_terrier       0.81      0.52      0.63        83\n",
      "   n02098105-soft-coated_wheaten_terrier       0.73      0.57      0.64        56\n",
      "   n02098286-West_Highland_white_terrier       0.70      0.86      0.77        69\n",
      "                         n02098413-Lhasa       0.54      0.62      0.57        86\n",
      "         n02099267-flat-coated_retriever       0.75      0.79      0.77        52\n",
      "        n02099429-curly-coated_retriever       0.81      0.67      0.73        51\n",
      "              n02099601-golden_retriever       0.78      0.78      0.78        50\n",
      "            n02099712-Labrador_retriever       0.58      0.75      0.65        71\n",
      "      n02099849-Chesapeake_Bay_retriever       0.61      0.90      0.72        67\n",
      "   n02100236-German_short-haired_pointer       0.80      0.85      0.82        52\n",
      "                        n02100583-vizsla       0.83      0.80      0.81        54\n",
      "                n02100735-English_setter       0.72      0.70      0.71        61\n",
      "                  n02100877-Irish_setter       0.87      0.82      0.84        55\n",
      "                 n02101006-Gordon_setter       0.87      0.91      0.89        53\n",
      "              n02101388-Brittany_spaniel       0.85      0.75      0.80        52\n",
      "                       n02101556-clumber       0.83      0.86      0.84        50\n",
      "              n02102040-English_springer       0.81      0.92      0.86        59\n",
      "        n02102177-Welsh_springer_spaniel       0.90      0.76      0.83        50\n",
      "                n02102318-cocker_spaniel       0.81      0.71      0.76        59\n",
      "                n02102480-Sussex_spaniel       0.92      0.88      0.90        51\n",
      "           n02102973-Irish_water_spaniel       0.74      0.84      0.79        50\n",
      "                        n02104029-kuvasz       0.48      0.74      0.58        50\n",
      "                    n02104365-schipperke       0.82      0.83      0.83        54\n",
      "                   n02105056-groenendael       0.82      0.92      0.87        50\n",
      "                      n02105162-malinois       0.75      0.78      0.76        50\n",
      "                        n02105251-briard       0.78      0.69      0.73        52\n",
      "                        n02105412-kelpie       0.57      0.81      0.67        53\n",
      "                      n02105505-komondor       0.92      0.89      0.91        54\n",
      "          n02105641-Old_English_sheepdog       0.82      0.81      0.82        69\n",
      "             n02105855-Shetland_sheepdog       0.93      0.49      0.64        57\n",
      "                        n02106030-collie       0.55      0.68      0.61        53\n",
      "                 n02106166-Border_collie       0.69      0.74      0.71        50\n",
      "          n02106382-Bouvier_des_Flandres       0.70      0.64      0.67        50\n",
      "                    n02106550-Rottweiler       0.67      0.98      0.80        52\n",
      "               n02106662-German_shepherd       0.75      0.85      0.79        52\n",
      "                      n02107142-Doberman       0.75      0.72      0.73        50\n",
      "            n02107312-miniature_pinscher       0.85      0.71      0.77        84\n",
      "    n02107574-Greater_Swiss_Mountain_dog       0.59      0.79      0.68        68\n",
      "          n02107683-Bernese_mountain_dog       0.92      0.92      0.92       118\n",
      "                   n02107908-Appenzeller       0.49      0.65      0.56        51\n",
      "                   n02108000-EntleBucher       0.90      0.62      0.73       102\n",
      "                         n02108089-boxer       0.60      0.67      0.63        51\n",
      "                  n02108422-bull_mastiff       0.78      0.77      0.77        56\n",
      "               n02108551-Tibetan_mastiff       0.62      0.77      0.69        52\n",
      "                n02108915-French_bulldog       0.80      0.86      0.83        59\n",
      "                    n02109047-Great_Dane       0.63      0.64      0.64        56\n",
      "                 n02109525-Saint_Bernard       0.93      0.96      0.94        70\n",
      "                    n02109961-Eskimo_dog       0.30      0.58      0.39        50\n",
      "                      n02110063-malamute       0.56      0.69      0.62        78\n",
      "                n02110185-Siberian_husky       0.58      0.15      0.24        92\n",
      "                 n02110627-affenpinscher       0.87      0.78      0.82        50\n",
      "                       n02110806-basenji       0.89      0.79      0.83       109\n",
      "                           n02110958-pug       0.87      0.84      0.85       100\n",
      "                      n02111129-Leonberg       0.89      0.91      0.90       110\n",
      "                  n02111277-Newfoundland       0.65      0.77      0.71        95\n",
      "                n02111500-Great_Pyrenees       0.73      0.58      0.64       113\n",
      "                       n02111889-Samoyed       0.69      0.98      0.81       118\n",
      "                    n02112018-Pomeranian       0.98      0.79      0.87       119\n",
      "                          n02112137-chow       0.87      0.99      0.93        96\n",
      "                      n02112350-keeshond       0.90      0.95      0.92        58\n",
      "             n02112706-Brabancon_griffon       0.93      0.81      0.87        53\n",
      "                      n02113023-Pembroke       0.84      0.77      0.80        81\n",
      "                      n02113186-Cardigan       0.67      0.75      0.71        55\n",
      "                    n02113624-toy_poodle       0.58      0.43      0.49        51\n",
      "              n02113712-miniature_poodle       0.35      0.53      0.42        55\n",
      "               n02113799-standard_poodle       0.60      0.66      0.63        59\n",
      "              n02113978-Mexican_hairless       0.92      0.82      0.87        55\n",
      "                         n02115641-dingo       0.67      0.80      0.73        56\n",
      "                         n02115913-dhole       0.96      0.86      0.91        50\n",
      "           n02116738-African_hunting_dog       0.80      0.94      0.87        69\n",
      "\n",
      "                                accuracy                           0.76      8580\n",
      "                               macro avg       0.76      0.75      0.75      8580\n",
      "                            weighted avg       0.78      0.76      0.76      8580\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Per-class precision / recall / F1 summary of y_pred against y_true.\n",
    "# NOTE(review): target_names is taken from the key order of `categories`;\n",
    "# presumably that order lines up with the label ids -- confirm.\n",
    "print(\n",
    "    metrics.classification_report(\n",
    "        y_true,\n",
    "        y_pred,\n",
    "        target_names=list(categories.keys()),\n",
    "    )\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
